{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Authoring Leaf Systems\n",
    "For instructions on how to run these tutorial notebooks, please see the [index](./index.ipynb).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.common.containers import namedview\n",
    "from pydrake.common.value import Value\n",
    "from pydrake.math import RigidTransform, RotationMatrix\n",
    "from pydrake.systems.analysis import Simulator\n",
    "from pydrake.systems.framework import BasicVector, LeafSystem\n",
    "from pydrake.trajectories import PiecewisePolynomial"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "The [Modeling Dynamical Systems](./dynamical_systems.ipynb) tutorial gave a very basic introduction to Drake's Systems framework, including writing a basic [LeafSystem](https://drake.mit.edu/doxygen_cxx/classdrake_1_1systems_1_1_leaf_system.html). In this notebook we'll provide a more advanced/complete overview of authoring those systems."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input and Output Ports\n",
    "\n",
    "Leaf systems can have any number of input and output ports.  There are two basic types of ports: **vector-valued** ports and **abstract-valued** ports. \n",
    "\n",
    "### Vector-valued Ports\n",
    "Let's start with vector-valued ports, which are the easiest to work with.  The following system declares two vector input ports, and two vector output ports. It forms the sum and difference of the two 2-element input vectors and places the results on the two output ports.\n",
    "\n",
    "<table align=center cellpadding=0 cellspacing=0><tr align=center style=\"border:none;\"><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0><tr><td align=right style=\"padding:5px 0px 5px 0px; border:none;\">a &rarr;</td></tr><tr style=\"border:none;\"><td align=right style=\"padding:5px 0px 5px 0px; border:none;\">b &rarr;</td></tr></table></td><td align=center style=\"border:2px solid black;padding-left:20px;padding-right:20px;vertical-align:middle;\" bgcolor=#F0F0F0>MyAdder</td><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0><tr><td align=left style=\"padding:5px 0px 5px 0px; border:none;\">&rarr; sum</td></tr><tr><td align=left style=\"padding:5px 0px 5px 0px; border:none;\">&rarr; difference</td></tr></table></td></tr></table>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyAdder(LeafSystem):\n",
    "    def __init__(self):\n",
    "        super().__init__()  # Don't forget to initialize the base class.\n",
    "        self._a_port = self.DeclareVectorInputPort(name=\"a\", size=2)\n",
    "        self._b_port = self.DeclareVectorInputPort(name=\"b\", size=2)\n",
    "        self.DeclareVectorOutputPort(name=\"sum\", size=2, calc=self.CalcSum)\n",
    "        self.DeclareVectorOutputPort(name=\"difference\",\n",
    "                                     size=2,\n",
    "                                     calc=self.CalcDifference)\n",
    "\n",
    "    def CalcSum(self, context, output):\n",
    "        # Evaluate the input ports to obtain the 2x1 vectors.\n",
    "        a = self._a_port.Eval(context)\n",
    "        b = self._b_port.Eval(context)\n",
    "\n",
    "        # Write the sum into the output vector.\n",
    "        output.SetFromVector(a + b)\n",
    "\n",
    "    def CalcDifference(self, context, output):\n",
    "        # Evaluate the input ports to obtain the 2x1 vectors.\n",
    "        a = self._a_port.Eval(context)\n",
    "        b = self._b_port.Eval(context)\n",
    "\n",
    "        # Write the difference into output vector.\n",
    "        output.SetFromVector(a - b)\n",
    "\n",
    "# Construct an instance of this system and a context.\n",
    "system = MyAdder()\n",
    "context = system.CreateDefaultContext()\n",
    "\n",
    "# Fix the input ports to some constant values.\n",
    "system.GetInputPort(\"a\").FixValue(context, [3, 4])\n",
    "system.GetInputPort(\"b\").FixValue(context, [1, 2])\n",
    "\n",
    "# Evaluate the output ports.\n",
    "print(f\"sum: {system.GetOutputPort('sum').Eval(context)}\")\n",
    "print(f\"difference: {system.GetOutputPort('difference').Eval(context)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are a few things to notice here:\n",
    "- You can declare as many inputs or outputs as you like (and they need not be the same size).\n",
    "- Any of the inputs can be evaluated via their `Eval()` method from any of the system methods that accept a `Context`.\n",
    "- For each output port, you define a different callback function that implements the output.\n",
    "- Referring to ports by their position (e.g., `get_input_port(0)`) is brittle, because the index depends on the order of declaration. It is more robust to keep the port objects returned by the `Declare...Port()` methods (as done with `self._a_port` above) or to look ports up by name (as done with `GetInputPort(\"a\")`).\n",
    "\n",
    "It is also helpful to understand that:\n",
    "- You can make inputs optional by checking [InputPort::HasValue()](https://drake.mit.edu/doxygen_cxx/classdrake_1_1systems_1_1_input_port.html#a5536b94a4642fa4cf47164437dc66ae8) in the system method implementations.\n",
    "- Input and output ports do not have any timing semantics on their own; output ports are evaluated whenever their output is requested, and [use caching](https://drake.mit.edu/doxygen_cxx/group__cache__design__notes.html) by default to avoid repeated computation.\n",
    "- The `DeclareVectorOutputPort()` and `DeclareAbstractOutputPort()` methods take an optional `prerequisites_of_calc` argument, which the caching system uses to determine when the output port needs to be re-evaluated. Declaring narrow prerequisites can speed up your code. Surprisingly, explicitly narrowing the prerequisites can also dramatically speed up the construction of some `Diagram`s: when the prerequisites are not specified, the `DiagramBuilder` checks for algebraic loops by converting your systems into symbolic form.\n"
   ]
  },
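  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell is a minimal sketch of the last two points. The `OptionalOffset` class and its port names are hypothetical, made up for this illustration. It treats the `offset` input as optional by checking `HasValue()`, and it passes `prerequisites_of_calc` so that the output depends only on the two input ports rather than on all sources. We assume here that `input_port_ticket()` accepts the index returned by `InputPort.get_index()`; please check the pydrake `SystemBase` documentation for your Drake version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.systems.framework import LeafSystem\n",
    "\n",
    "\n",
    "class OptionalOffset(LeafSystem):\n",
    "    \"\"\"Hypothetical example: y = u, plus offset if that port has a value.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self._u_port = self.DeclareVectorInputPort(name=\"u\", size=2)\n",
    "        self._offset_port = self.DeclareVectorInputPort(name=\"offset\", size=2)\n",
    "        # The output depends only on these two inputs (not on time, state, or\n",
    "        # parameters), so declare exactly those prerequisites.\n",
    "        self.DeclareVectorOutputPort(\n",
    "            name=\"y\",\n",
    "            size=2,\n",
    "            calc=self.CalcY,\n",
    "            prerequisites_of_calc={\n",
    "                self.input_port_ticket(self._u_port.get_index()),\n",
    "                self.input_port_ticket(self._offset_port.get_index()),\n",
    "            })\n",
    "\n",
    "    def CalcY(self, context, output):\n",
    "        y = np.array(self._u_port.Eval(context))\n",
    "        # Only use the optional input if it is connected or fixed.\n",
    "        if self._offset_port.HasValue(context):\n",
    "            y = y + self._offset_port.Eval(context)\n",
    "        output.SetFromVector(y)\n",
    "\n",
    "\n",
    "system = OptionalOffset()\n",
    "context = system.CreateDefaultContext()\n",
    "system.GetInputPort(\"u\").FixValue(context, [1.0, 2.0])\n",
    "print(f\"without offset: {system.GetOutputPort('y').Eval(context)}\")\n",
    "system.GetInputPort(\"offset\").FixValue(context, [10.0, 20.0])\n",
    "print(f\"with offset: {system.GetOutputPort('y').Eval(context)}\")"
   ]
  },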
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Abstract-valued Ports\n",
    "Not all inputs and outputs are best represented as a vector. Drake's systems framework also supports passing structured types across ports, which is implemented in C++ using a technique called \"type erasure\". The framework itself only handles an opaque \"abstract\" value (an `AbstractValue`); only the systems that write and read a port need to know the concrete type stored inside it. In practice, once you accept a small amount of boilerplate (such as wrapping model values in `Value(...)`), usage is straightforward. Here is an example of a `LeafSystem` that uses a `RigidTransform` on its input and output ports:\n",
    "\n",
    "<table align=center cellpadding=0 cellspacing=0><tr align=center><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0><tr><td align=right style=\"padding:5px 0px 5px 0px; border:none;\">in &rarr;</td></tr></table></td><td align=center style=\"border:2px solid black;padding-left:20px;padding-right:20px;vertical-align:middle\" bgcolor=#F0F0F0>RotateAboutZ</td><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0><tr><td align=left style=\"padding:5px 0px 5px 0px; border:none;\">&rarr; out</td></tr></table></td></tr></table>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RotateAboutZ(LeafSystem):\n",
    "    def __init__(self):\n",
    "        super().__init__()  # Don't forget to initialize the base class.\n",
    "        self.DeclareAbstractInputPort(name=\"in\",\n",
    "                                      model_value=Value(RigidTransform()))\n",
    "        self.DeclareAbstractOutputPort(\n",
    "            name=\"out\",\n",
    "            alloc=lambda: Value(RigidTransform()),\n",
    "            calc=self.CalcOutput)\n",
    "\n",
    "    def CalcOutput(self, context, output):\n",
    "        # Evaluate the input port to obtain the RigidTransform.\n",
    "        X_1 = self.get_input_port().Eval(context)\n",
    "\n",
    "        # Rotate the input by 90 degrees about the z axis of the parent frame.\n",
    "        X_2 = RigidTransform(RotationMatrix.MakeZRotation(np.pi / 2)) @ X_1\n",
    "\n",
    "        # Set the output RigidTransform.\n",
    "        output.set_value(X_2)\n",
    "\n",
    "# Construct an instance of this system and a context.\n",
    "system = RotateAboutZ()\n",
    "context = system.CreateDefaultContext()\n",
    "\n",
    "# Fix the input port to a constant value.\n",
    "system.get_input_port().FixValue(context, RigidTransform())\n",
    "\n",
    "# Evaluate the output port.\n",
    "print(f\"output: {system.get_output_port(0).Eval(context)}\")"
   ]
  },
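  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Abstract and vector ports can be mixed freely in one system. As a small sketch (the `TranslationOf` class and its port names are hypothetical), the next cell takes a `RigidTransform` on an abstract input port and reports its translation on a 3-element vector output port. Note that `Eval()` on an abstract input port returns the stored `RigidTransform` object itself, so no unwrapping is needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.common.value import Value\n",
    "from pydrake.math import RigidTransform, RotationMatrix\n",
    "from pydrake.systems.framework import LeafSystem\n",
    "\n",
    "\n",
    "class TranslationOf(LeafSystem):\n",
    "    \"\"\"Hypothetical example: abstract (RigidTransform) in, vector out.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self._X_port = self.DeclareAbstractInputPort(\n",
    "            name=\"X\", model_value=Value(RigidTransform()))\n",
    "        self.DeclareVectorOutputPort(\n",
    "            name=\"p\", size=3, calc=self.CalcTranslation)\n",
    "\n",
    "    def CalcTranslation(self, context, output):\n",
    "        # Eval() on an abstract port returns the RigidTransform itself.\n",
    "        X = self._X_port.Eval(context)\n",
    "        output.SetFromVector(X.translation())\n",
    "\n",
    "\n",
    "system = TranslationOf()\n",
    "context = system.CreateDefaultContext()\n",
    "X = RigidTransform(RotationMatrix.MakeZRotation(np.pi / 2), [1.0, 0.0, 0.0])\n",
    "system.GetInputPort(\"X\").FixValue(context, X)\n",
    "print(f\"translation: {system.GetOutputPort('p').Eval(context)}\")"
   ]
  },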
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State Variables\n",
    "\n",
    "Systems store their state in the `Context`, and work with state in a very similar way to input and output ports. Again, we can have vector-valued state or abstract-valued state. Abstract state can only be updated in a discrete-time fashion, but vector-valued state can be declared to be either **discrete** or **continuous**.  \n",
    "\n",
    "A discrete state $x_d$ will evolve based on update events, the most common of which would be a simple periodic event to define a difference equation.  We saw an example of this in the [introductory notebook](./dynamical_systems.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleDiscreteTimeSystem(LeafSystem):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        state_index = self.DeclareDiscreteState(1)  # One state variable.\n",
    "        self.DeclareStateOutputPort(\"y\", state_index)  # One output: y=x.\n",
    "        self.DeclarePeriodicDiscreteUpdateEvent(\n",
    "            period_sec=1.0,  # One second time step.\n",
    "            offset_sec=0.0,  # The first event is at time zero.\n",
    "            update=self.Update) # Call the Update method defined below.\n",
    "\n",
    "    # x[n+1] = x^3[n].\n",
    "    def Update(self, context, discrete_state):\n",
    "        x = context.get_discrete_state_vector().GetAtIndex(0)\n",
    "        x_next = x**3\n",
    "        discrete_state.get_mutable_vector().SetAtIndex(0, x_next)\n",
    "\n",
    "# Instantiate the System.\n",
    "system = SimpleDiscreteTimeSystem()\n",
    "simulator = Simulator(system)\n",
    "context = simulator.get_mutable_context()\n",
    "\n",
    "# Set the initial conditions: x[0] = [0.9].\n",
    "context.get_mutable_discrete_state_vector().SetFromVector([0.9])\n",
    "\n",
    "# Run the simulation.\n",
    "simulator.AdvanceTo(4.0)\n",
    "print(context.get_discrete_state_vector())"
   ]
  },
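  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A system may declare several groups of discrete state, each updated by its own event. The next cell is a sketch of that idea (the `TwoRateCounter` class is hypothetical): one group is incremented every 0.1 s and another every 1 s. Each update callback receives the full `DiscreteValues` and writes only the group it owns; we rely on Drake pre-initializing that output with the current discrete state before calling the handlers. The exact counts depend on Drake's event-timing rules (for example, whether an event scheduled exactly at the final time is processed), so we simply print them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydrake.systems.analysis import Simulator\n",
    "from pydrake.systems.framework import LeafSystem\n",
    "\n",
    "\n",
    "class TwoRateCounter(LeafSystem):\n",
    "    \"\"\"Hypothetical example: two discrete state groups, two update rates.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Each call to DeclareDiscreteState() creates a separate group.\n",
    "        self.fast_index = int(self.DeclareDiscreteState(1))\n",
    "        self.slow_index = int(self.DeclareDiscreteState(1))\n",
    "        self.DeclarePeriodicDiscreteUpdateEvent(\n",
    "            period_sec=0.1, offset_sec=0.0, update=self.UpdateFast)\n",
    "        self.DeclarePeriodicDiscreteUpdateEvent(\n",
    "            period_sec=1.0, offset_sec=0.0, update=self.UpdateSlow)\n",
    "\n",
    "    def UpdateFast(self, context, discrete_state):\n",
    "        n = context.get_discrete_state(self.fast_index).GetAtIndex(0)\n",
    "        discrete_state.get_mutable_vector(self.fast_index).SetAtIndex(0, n + 1)\n",
    "\n",
    "    def UpdateSlow(self, context, discrete_state):\n",
    "        n = context.get_discrete_state(self.slow_index).GetAtIndex(0)\n",
    "        discrete_state.get_mutable_vector(self.slow_index).SetAtIndex(0, n + 1)\n",
    "\n",
    "\n",
    "system = TwoRateCounter()\n",
    "simulator = Simulator(system)\n",
    "simulator.AdvanceTo(2.0)\n",
    "context = simulator.get_context()\n",
    "fast = context.get_discrete_state(system.fast_index).GetAtIndex(0)\n",
    "slow = context.get_discrete_state(system.slow_index).GetAtIndex(0)\n",
    "print(f\"fast updates: {fast}, slow updates: {slow}\")"
   ]
  },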
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can find details on the implementation and order of these events [here](https://drake.mit.edu/doxygen_cxx/group__discrete__systems.html). It's also possible to define more complicated [events](https://drake.mit.edu/doxygen_cxx/group__events__description.html) for updating the discrete variables.\n",
    "\n",
    "Note that the states of a system are not directly accessible to other systems in a `Diagram`. For a system to share its state, it must do so through an output port. We provide the `DeclareStateOutputPort()` method to make this common case easy.\n",
    "\n",
    "The evolution of continuous vector-valued state variables is governed by a differential equation. While it is possible and convenient to declare many different groups of discrete state variables (they may even be updated at different rates by different events), a system has at most one continuous state vector, and its dynamics are defined by overriding the `LeafSystem::DoCalcTimeDerivatives()` method. Here, again, is the simple continuous-time example from the [introductory notebook](./dynamical_systems.ipynb):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the system.\n",
    "class SimpleContinuousTimeSystem(LeafSystem):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        state_index = self.DeclareContinuousState(1)  # One state variable.\n",
    "        self.DeclareStateOutputPort(\"y\", state_index)  # One output: y=x.\n",
    "\n",
    "    # xdot(t) = -x(t) + x^3(t).\n",
    "    def DoCalcTimeDerivatives(self, context, derivatives):\n",
    "        x = context.get_continuous_state_vector().GetAtIndex(0)\n",
    "        xdot = -x + x**3\n",
    "        derivatives.get_mutable_vector().SetAtIndex(0, xdot)\n",
    "\n",
    "# Instantiate the System.\n",
    "system = SimpleContinuousTimeSystem()\n",
    "simulator = Simulator(system)\n",
    "context = simulator.get_mutable_context()\n",
    "\n",
    "# Set the initial conditions: x(0) = [0.9]\n",
    "context.SetContinuousState([0.9])\n",
    "\n",
    "# Run the simulation.\n",
    "simulator.AdvanceTo(4.0)\n",
    "print(context.get_continuous_state_vector())"
   ]
  },
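  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It can be handy to check a `DoCalcTimeDerivatives()` implementation without running a simulation. The next cell is a sketch that evaluates the same dynamics, $\\dot{x} = -x + x^3$, at a few states using `EvalTimeDerivatives()` (the class name `CubicDynamics` is ours). By hand, $\\dot{x}$ is $-0.171$ at $x = 0.9$, $0$ at the equilibrium $x = 1$, and $0.231$ at $x = 1.1$, which is why the simulation above decays toward zero from $x(0) = 0.9$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydrake.systems.framework import LeafSystem\n",
    "\n",
    "\n",
    "class CubicDynamics(LeafSystem):\n",
    "    \"\"\"The same dynamics as above, xdot = -x + x^3 (class name is ours).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.DeclareContinuousState(1)\n",
    "\n",
    "    def DoCalcTimeDerivatives(self, context, derivatives):\n",
    "        x = context.get_continuous_state_vector().GetAtIndex(0)\n",
    "        derivatives.get_mutable_vector().SetAtIndex(0, -x + x**3)\n",
    "\n",
    "\n",
    "system = CubicDynamics()\n",
    "context = system.CreateDefaultContext()\n",
    "for x0 in [0.9, 1.0, 1.1]:\n",
    "    context.SetContinuousState([x0])\n",
    "    # EvalTimeDerivatives() calls DoCalcTimeDerivatives() and caches the result.\n",
    "    xdot = system.EvalTimeDerivatives(context).CopyToVector()[0]\n",
    "    print(f\"x = {x0}: xdot = {xdot:.4f}\")"
   ]
  },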
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For abstract-valued state, we have a similar workflow. We declare the abstract state in the constructor, and define update events to update that state.  Discrete update events are restricted to updating *only* the discrete states, so we use \"unrestricted\" events to update the abstract state. Here is an example which stores a simple `PiecewisePolynomial` trajectory as abstract state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AbstractStateSystem(LeafSystem):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self._traj_index = self.DeclareAbstractState(\n",
    "            Value(PiecewisePolynomial()))\n",
    "        self.DeclarePeriodicUnrestrictedUpdateEvent(period_sec=1.0,\n",
    "                                                    offset_sec=0.0,\n",
    "                                                    update=self.Update)\n",
    "\n",
    "    def Update(self, context, state):\n",
    "        t = context.get_time()\n",
    "        traj = PiecewisePolynomial.FirstOrderHold(\n",
    "            [t, t + 1],\n",
    "            np.array([[-np.pi / 2.0 + 1., -np.pi / 2.0 - 1.], [-2., 2.]]))\n",
    "        # Update the state\n",
    "        state.get_mutable_abstract_state(int(self._traj_index)).set_value(traj)\n",
    "\n",
    "\n",
    "\n",
    "system = AbstractStateSystem()\n",
    "simulator = Simulator(system)\n",
    "context = simulator.get_mutable_context()\n",
    "\n",
    "# Set an initial condition for the abstract state.\n",
    "context.SetAbstractState(0, PiecewisePolynomial())\n",
    "\n",
    "# Run the simulation.\n",
    "simulator.AdvanceTo(4.0)\n",
    "traj = context.get_abstract_state(0).get_value()\n",
    "print(f\"breaks: {traj.get_segment_times()}\")\n",
    "print(f\"traj({context.get_time()}) = {traj.value(context.get_time())}\")"
   ]
  },
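  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since other systems cannot read our state directly, a common pattern is to expose a function of the abstract state on an output port. In the sketch below (the `TrajectoryFollower` class and its `y` port are hypothetical), the unrestricted update stores a new first-order-hold trajectory every second, and the output port evaluates the current trajectory at the context time. The output guards against the default, empty `PiecewisePolynomial`, which has no segments to evaluate. The update at $t = 2$ interpolates from $[0, 0]$ to $[1, -1]$ over $[2, 3]$, so the output at $t = 2.5$ should be $[0.5, -0.5]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.common.value import Value\n",
    "from pydrake.systems.analysis import Simulator\n",
    "from pydrake.systems.framework import LeafSystem\n",
    "from pydrake.trajectories import PiecewisePolynomial\n",
    "\n",
    "\n",
    "class TrajectoryFollower(LeafSystem):\n",
    "    \"\"\"Hypothetical example: a trajectory in abstract state, sampled on a\n",
    "    vector output port.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self._traj_index = int(\n",
    "            self.DeclareAbstractState(Value(PiecewisePolynomial())))\n",
    "        self.DeclarePeriodicUnrestrictedUpdateEvent(\n",
    "            period_sec=1.0, offset_sec=0.0, update=self.Update)\n",
    "        self.DeclareVectorOutputPort(name=\"y\", size=2, calc=self.CalcY)\n",
    "\n",
    "    def Update(self, context, state):\n",
    "        # Store a ramp from [0, 0] to [1, -1] over the next second.\n",
    "        t = context.get_time()\n",
    "        traj = PiecewisePolynomial.FirstOrderHold(\n",
    "            [t, t + 1], np.array([[0.0, 1.0], [0.0, -1.0]]))\n",
    "        state.get_mutable_abstract_state(self._traj_index).set_value(traj)\n",
    "\n",
    "    def CalcY(self, context, output):\n",
    "        traj = context.get_abstract_state(self._traj_index).get_value()\n",
    "        if traj.get_number_of_segments() == 0:\n",
    "            # The default (empty) trajectory cannot be evaluated.\n",
    "            output.SetFromVector(np.zeros(2))\n",
    "            return\n",
    "        # value() returns a 2x1 matrix; flatten it for the vector port.\n",
    "        output.SetFromVector(traj.value(context.get_time()).ravel())\n",
    "\n",
    "\n",
    "system = TrajectoryFollower()\n",
    "simulator = Simulator(system)\n",
    "simulator.AdvanceTo(2.5)\n",
    "print(f\"y(2.5) = {system.get_output_port().Eval(simulator.get_context())}\")"
   ]
  },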
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters\n",
    "\n",
    "Parameters are like state, except that they are constant through the lifetime of a simulation. Parameters in the systems framework are declared and accessed almost identically to state, but they are never updated. Once again, we have vector-valued (declared with `DeclareNumericParameter`) and abstract-valued (via `DeclareAbstractParameter`) parameters.\n",
    "\n",
    "<table align=center cellpadding=0 cellspacing=0><tr align=center><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0></table></td><td align=center style=\"border:2px solid black;padding-left:20px;padding-right:20px;vertical-align:middle\" bgcolor=#F0F0F0>SystemWithParameters</td><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0><tr><td align=left style=\"padding:5px 0px 5px 0px; border:none;\">&rarr; numeric</td></tr><tr><td align=left style=\"padding:5px 0px 5px 0px; border:none;\">&rarr; abstract</td></tr></table></td></tr></table>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SystemWithParameters(LeafSystem):\n",
    "    def __init__(self):\n",
    "        super().__init__()  # Don't forget to initialize the base class.\n",
    "\n",
    "        self.DeclareNumericParameter(BasicVector([1.2, 3.4]))\n",
    "        self.DeclareAbstractParameter(\n",
    "            Value(RigidTransform(RotationMatrix.MakeXRotation(np.pi / 6))))\n",
    "\n",
    "        # Declare output ports to demonstrate how to access the parameters in\n",
    "        # system methods.\n",
    "        self.DeclareVectorOutputPort(name=\"numeric\",\n",
    "                                     size=2,\n",
    "                                     calc=self.OutputNumeric)\n",
    "        self.DeclareAbstractOutputPort(\n",
    "            name=\"abstract\",\n",
    "            alloc=lambda: Value(RigidTransform()),\n",
    "            calc=self.OutputAbstract)\n",
    "\n",
    "    def OutputNumeric(self, context, output):\n",
    "        output.SetFromVector(context.get_numeric_parameter(0).get_value())\n",
    "\n",
    "    def OutputAbstract(self, context, output):\n",
    "        output.set_value(context.get_abstract_parameter(0).get_value())\n",
    "\n",
    "# Construct an instance of this system and a context.\n",
    "system = SystemWithParameters()\n",
    "context = system.CreateDefaultContext()\n",
    "\n",
    "# Evaluate the output ports.\n",
    "print(f\"numeric: {system.get_output_port(0).Eval(context)}\")\n",
    "print(f\"abstract: {system.get_output_port(1).Eval(context)}\")"
   ]
  },
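  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parameters are constant during a simulation, but because they live in the `Context` they can be changed between evaluations or simulations without rebuilding the system. The sketch below uses a hypothetical `Gain` system with one numeric parameter $k$ and computes $y = k u$, first with the default $k = 2$ and then after writing $k = 5$ into the context through `get_mutable_numeric_parameter()`. With $u = 3$ this should give $6$ and then $15$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydrake.systems.framework import BasicVector, LeafSystem\n",
    "\n",
    "\n",
    "class Gain(LeafSystem):\n",
    "    \"\"\"Hypothetical example: y = k * u, with k stored as a numeric parameter.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # DeclareNumericParameter() returns the parameter's index.\n",
    "        self.k_index = self.DeclareNumericParameter(BasicVector([2.0]))\n",
    "        self._u_port = self.DeclareVectorInputPort(name=\"u\", size=1)\n",
    "        self.DeclareVectorOutputPort(name=\"y\", size=1, calc=self.CalcY)\n",
    "\n",
    "    def CalcY(self, context, output):\n",
    "        k = context.get_numeric_parameter(self.k_index).GetAtIndex(0)\n",
    "        output.SetFromVector(k * self._u_port.Eval(context))\n",
    "\n",
    "\n",
    "system = Gain()\n",
    "context = system.CreateDefaultContext()\n",
    "system.get_input_port().FixValue(context, [3.0])\n",
    "print(f\"default k: {system.get_output_port().Eval(context)}\")\n",
    "\n",
    "# Change the parameter in this context only; the System itself is unchanged.\n",
    "context.get_mutable_numeric_parameter(system.k_index).SetAtIndex(0, 5.0)\n",
    "print(f\"k = 5: {system.get_output_port().Eval(context)}\")"
   ]
  },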
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Systems can \"publish\"\n",
    "\n",
    "In addition to methods that update the state, and methods that evaluate an output port, another supported method in a `LeafSystem` is a callback to \"publish\". A \"publish\" method cannot modify any state: they are useful for broadcasting data outside of the systems framework (e.g. via a message-passing protocol like ROS), for terminating a simulation, for detecting errors, and for forcing boundaries between integration steps. They are declared in much the same way as other system callbacks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyPublishingSystem(LeafSystem):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Calling `ForcePublish()` will trigger the callback.\n",
    "        self.DeclareForcedPublishEvent(self.Publish)\n",
    "\n",
    "        # Publish once every second.\n",
    "        self.DeclarePeriodicPublishEvent(period_sec=1,\n",
    "                                         offset_sec=0,\n",
    "                                         publish=self.Publish)\n",
    "        \n",
    "    def Publish(self, context):\n",
    "        print(f\"Publish() called at time={context.get_time()}\")\n",
    "\n",
    "system = MyPublishingSystem()\n",
    "simulator = Simulator(system)\n",
    "simulator.AdvanceTo(5.3)\n",
    "\n",
    "# We can also \"force\" a publish at a arbitrary time.\n",
    "print(\"\\ncalling ForcedPublish:\")\n",
    "system.ForcedPublish(simulator.get_context())\n"
   ]
  },
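  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because a publish callback may have side effects outside the `Context`, a simple use is recording data into an ordinary Python list. The sketch below (the `DecayLogger` class and its `samples` attribute are hypothetical) integrates $\\dot{x} = -x$ from $x(0) = 1$ and records $(t, x)$ every 0.5 s. The periodic publish also forces the integrator to stop at each sample time, so the recorded values should be close to $e^{-t}$. Note that the list belongs to the `System` object, not the `Context`, so it keeps growing if you run more than one simulation with the same instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.systems.analysis import Simulator\n",
    "from pydrake.systems.framework import LeafSystem\n",
    "\n",
    "\n",
    "class DecayLogger(LeafSystem):\n",
    "    \"\"\"Hypothetical example: xdot = -x, sampled by a periodic publish.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.DeclareContinuousState(1)\n",
    "        self.DeclarePeriodicPublishEvent(\n",
    "            period_sec=0.5, offset_sec=0.0, publish=self.Record)\n",
    "        # Stored on the System object, outside the Context and its state.\n",
    "        self.samples = []\n",
    "\n",
    "    def DoCalcTimeDerivatives(self, context, derivatives):\n",
    "        x = context.get_continuous_state_vector().GetAtIndex(0)\n",
    "        derivatives.get_mutable_vector().SetAtIndex(0, -x)\n",
    "\n",
    "    def Record(self, context):\n",
    "        # Publish callbacks may read the Context but must not modify it.\n",
    "        x = context.get_continuous_state_vector().GetAtIndex(0)\n",
    "        self.samples.append((context.get_time(), x))\n",
    "\n",
    "\n",
    "system = DecayLogger()\n",
    "simulator = Simulator(system)\n",
    "simulator.get_mutable_context().SetContinuousState([1.0])\n",
    "simulator.AdvanceTo(2.0)\n",
    "for t, x in system.samples:\n",
    "    print(f\"t = {t:.1f}, x = {x:.4f}, exp(-t) = {np.exp(-t):.4f}\")"
   ]
  },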
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Naming Vector Values\n",
    "\n",
    "Abstract-valued ports, state, and/or parameters can be used to work with structured data. But it can also be convenient to use string names to reference the individual elements of vector-valued data. In Python, we recommend using a [`namedview`](https://drake.mit.edu/pydrake/pydrake.common.containers.html?highlight=namedview#pydrake.common.containers.namedview) workflow, which was inspired by Python's [`namedtuple`](https://docs.python.org/3.6/library/collections.html?highlight=namedtuple#collections.namedtuple) and is demonstrated in the example below. The Drake developers plan to [provide similar functionality in C++](https://github.com/RobotLocomotion/drake/issues/12566).\n",
    "\n",
    "Note that the memory layout for a continuous state vector is slightly different than for discrete state, parameters, and input/output port data, so the syntax for dealing with it requires a slightly distinct `CopyToVector()` and `SetFromVector()`.  See [issue #9171](https://github.com/RobotLocomotion/drake/issues/9171).\n",
    "\n",
    "Note that usages of `my_view[:]` are meant as a shorthand to get a NumPy array view of the `namedview` itself.\n",
    "For more details, see the example and warning in the documentation for [`namedview`](https://drake.mit.edu/pydrake/pydrake.common.containers.html?highlight=namedview#pydrake.common.containers.namedview)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the system.\n",
    "class NamedViewDemo(LeafSystem):\n",
    "    MyDiscreteState = namedview(\"MyDiscreteState\", [\"a\", \"b\"])\n",
    "    MyContinuousState = namedview(\"MyContinuousState\", [\"x\", \"z\", \"theta\"])\n",
    "    MyOutput = namedview(\"MyOutput\", [\"x\",\"a\"])\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.DeclareDiscreteState(2)\n",
    "        self.DeclarePeriodicDiscreteUpdateEvent(\n",
    "            period_sec=1.0,\n",
    "            offset_sec=0.0,\n",
    "            update=self.DiscreteUpdate)\n",
    "        self.DeclareContinuousState(3)\n",
    "        self.DeclareVectorOutputPort(name=\"out\", size=2, calc=self.CalcOutput)\n",
    "\n",
    "    def DiscreteUpdate(self, context, discrete_values):\n",
    "        discrete_state = self.MyDiscreteState(\n",
    "            context.get_discrete_state_vector().value())\n",
    "        continuous_state = self.MyContinuousState(\n",
    "            context.get_continuous_state_vector().CopyToVector())\n",
    "        next_state = self.MyDiscreteState(discrete_values.get_mutable_value())\n",
    "        # Now we can compute the next state by referencing each element by name.\n",
    "        next_state.a = discrete_state.a + 1\n",
    "        next_state.b = discrete_state.b + continuous_state.x\n",
    "\n",
    "    def DoCalcTimeDerivatives(self, context, derivatives):\n",
    "        continuous_state = self.MyContinuousState(\n",
    "            context.get_continuous_state_vector().CopyToVector())\n",
    "        dstate_dt = self.MyContinuousState(continuous_state[:])\n",
    "        dstate_dt.x = -continuous_state.x\n",
    "        dstate_dt.z = -continuous_state.z\n",
    "        dstate_dt.theta = -np.arctan2(continuous_state.z, continuous_state.x)\n",
    "        derivatives.SetFromVector(dstate_dt[:])\n",
    "\n",
    "    def CalcOutput(self, context, output):\n",
    "        discrete_state = self.MyDiscreteState(\n",
    "            context.get_discrete_state_vector().value())\n",
    "        continuous_state = self.MyContinuousState(\n",
    "            context.get_continuous_state_vector().CopyToVector())\n",
    "        out = self.MyOutput(output.get_mutable_value())\n",
    "        out.x = continuous_state.x\n",
    "        out.a = discrete_state.a\n",
    "\n",
    "# Instantiate the System.\n",
    "system = NamedViewDemo()\n",
    "simulator = Simulator(system)\n",
    "context = simulator.get_mutable_context()\n",
    "\n",
    "# Set the initial conditions.\n",
    "initial_discrete_state = NamedViewDemo.MyDiscreteState([3, 4])\n",
    "context.SetDiscreteState(initial_discrete_state[:])\n",
    "initial_continuous_state = NamedViewDemo.MyContinuousState.Zero()\n",
    "initial_continuous_state.x = 0.5\n",
    "initial_continuous_state.z = 0.92\n",
    "initial_continuous_state.theta = 0.23\n",
    "context.SetContinuousState(initial_continuous_state[:])\n",
    "\n",
    "# Run the simulation.\n",
    "simulator.AdvanceTo(4.0)\n",
    "print(\n",
    "    NamedViewDemo.MyDiscreteState(context.get_discrete_state_vector().value()))\n",
    "print(\n",
    "    NamedViewDemo.MyContinuousState(\n",
    "        context.get_continuous_state_vector().CopyToVector()))\n",
    "print(NamedViewDemo.MyOutput(system.get_output_port().Eval(context)))\n"
   ]
  },
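  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the aliasing behavior mentioned in the `namedview` documentation concrete, here is a small standalone sketch (the `Pose2d` view is a made-up example). Wrapping an existing array does not copy it, so writing a field through the view also changes the original array, whereas `Zero()` allocates fresh storage. This matters whenever a view of one quantity is used to compute another one in place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.common.containers import namedview\n",
    "\n",
    "# A made-up view with three named fields.\n",
    "Pose2d = namedview(\"Pose2d\", [\"x\", \"y\", \"theta\"])\n",
    "print(Pose2d.get_fields())\n",
    "\n",
    "data = np.array([1.0, 2.0, 0.5])\n",
    "pose = Pose2d(data)  # Wraps `data`; no copy is made.\n",
    "pose.theta = 0.0     # So this also writes into `data`.\n",
    "print(f\"data after writing pose.theta: {data}\")\n",
    "\n",
    "# Zero() allocates new storage, so `other` does not alias `data`.\n",
    "other = Pose2d.Zero()\n",
    "other.x = pose.x + 1.0\n",
    "print(f\"other: {other[:]}, data unchanged: {data}\")"
   ]
  },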
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Supporting Scalar Type Conversion (double, AutoDiff, and Symbolic)\n",
    "\n",
    "In order to support AutoDiff and Symbolic types throughout the systems framework, LeafSystems can be written to support \"scalar type conversion\". In Python, adding this support requires a small amount of boilerplate.  Here is a simple example:\n",
    "\n",
    "<table align=center cellpadding=0 cellspacing=0><tr align=center><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0><tr><td align=right style=\"padding:5px 0px 5px 0px; border:none;\">state &rarr;</td></tr><tr><td align=right style=\"padding:5px 0px 5px 0px; border:none;\">command &rarr;</td></tr></table></td><td align=center style=\"border:2px solid black;padding-left:20px;padding-right:20px;vertical-align:middle\" bgcolor=#F0F0F0>RunningCost</td><td style=\"vertical-align:middle; border:none;\"><table cellspacing=0 cellpadding=0><tr><td align=left style=\"padding:5px 0px 5px 0px; border:none;\">&rarr; cost</td></tr></table></td></tr></table>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydrake.systems.framework import LeafSystem_\n",
    "from pydrake.systems.scalar_conversion import TemplateSystem\n",
    "from pydrake.autodiffutils import AutoDiffXd\n",
    "from pydrake.symbolic import Expression\n",
    "\n",
    "@TemplateSystem.define(\"RunningCost_\")\n",
    "def RunningCost_(T):\n",
    "\n",
    "    class Impl(LeafSystem_[T]):\n",
    "\n",
    "        def _construct(self, converter=None, Q=np.eye(2)):\n",
    "            super().__init__(converter)\n",
    "            self._Q = Q\n",
    "            self._state_port = self.DeclareVectorInputPort(\"state\", 2)\n",
    "            self._command_port = self.DeclareVectorInputPort(\"command\", 1)\n",
    "            self.DeclareVectorOutputPort(\"cost\", 1, self.CostOutput)\n",
    "\n",
    "        def _construct_copy(self, other, converter=None):\n",
    "            # Any member fields (e.g. configuration values) need to be\n",
    "            # transferred here from `other` to `self`.\n",
    "            Impl._construct(self, converter=converter, Q=other._Q)\n",
    "\n",
    "        def CostOutput(self, context, output):\n",
    "            x = self._state_port.Eval(context)\n",
    "            u = self._command_port.Eval(context)[0]\n",
    "            output[0] = x.dot(self._Q.dot(x)) + u**2\n",
    "\n",
    "    return Impl\n",
    "\n",
    "RunningCost = RunningCost_[None]  # Default instantiation."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The important steps are:\n",
    "- Add the `@TemplateSystem` decorator.\n",
    "- Derive from `LeafSystem_[T]` instead of simply `LeafSystem`.\n",
    "- Implement the `_construct` method *instead* of the typical `__init__` method.\n",
    "- Implement the `_construct_copy` method, which needs to populate the same member fields as `_construct` (as we did with `self.Q` in this example).\n",
    "- Add the default instantiation, so that you can continue to refer to the system as, e.g., `RunningCost` in addition to using `RunningCost_[float]`.\n",
    "\n",
    "For further details, you can find the related documentation for scalar conversion in C++ [here](https://drake.mit.edu/doxygen_cxx/group__system__scalar__conversion.html) and the documentation for the [`@TemplateSystem`](https://drake.mit.edu/pydrake/pydrake.systems.scalar_conversion.html#pydrake.systems.scalar_conversion.TemplateSystem) decorator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Having done this, we can still use the system in the original way:\n",
    "system = RunningCost(Q=np.diag([10, 1]))\n",
    "context = system.CreateDefaultContext()\n",
    "system.get_input_port(0).FixValue(context, [1, 2])\n",
    "system.get_input_port(1).FixValue(context, [3])\n",
    "print(system.get_output_port().Eval(context))\n",
    "\n",
    "# But we can also now use an autodiff or symbolic versions of the system,\n",
    "# either by declaring them directly:\n",
    "system_ad = RunningCost_[AutoDiffXd]()\n",
    "system_symbolic = RunningCost_[Expression]()\n",
    "\n",
    "# or by scalar conversion:\n",
    "system_ad = system.ToAutoDiffXd()\n",
    "system_symbolic = system.ToSymbolic()\n",
    "\n",
    "# We can also convert the time, state, parameters, and (if needed) input ports\n",
    "# from the original system:\n",
    "context_symbolic = system_symbolic.CreateDefaultContext()\n",
    "context_symbolic.SetTimeStateAndParametersFrom(context)\n",
    "system_symbolic.FixInputPortsFrom(system, context, context_symbolic)\n",
    "print(system_symbolic.get_output_port().Eval(context_symbolic))"
   ]
  },
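  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check that scalar conversion does what we expect, the sketch below uses the AutoDiff version of `RunningCost` (defined above) to compute the gradient of the cost with respect to the state. The state input is seeded with an identity Jacobian, and the command is given zero derivatives, so it is treated as a constant. For $\\ell(x, u) = x^\\top Q x + u^2$ the gradient with respect to $x$ is $2 Q x$, so with $Q = \\mathrm{diag}(10, 1)$, $x = [1, 2]$, and $u = 3$ we expect a cost of $23$ and a gradient of $[20, 4]$. We assume that `FixValue()` accepts a NumPy array of `AutoDiffXd` for an AutoDiff-valued port."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.autodiffutils import AutoDiffXd\n",
    "\n",
    "# Requires the RunningCost template system defined in the cells above.\n",
    "system = RunningCost(Q=np.diag([10.0, 1.0]))\n",
    "system_ad = system.ToAutoDiffXd()\n",
    "context_ad = system_ad.CreateDefaultContext()\n",
    "\n",
    "# Seed the state with an identity Jacobian: d(x_i)/d(x_j) = 1 if i == j.\n",
    "x_ad = np.array([AutoDiffXd(1.0, np.array([1.0, 0.0])),\n",
    "                 AutoDiffXd(2.0, np.array([0.0, 1.0]))])\n",
    "# The command carries zero derivatives, i.e., it is treated as a constant.\n",
    "u_ad = np.array([AutoDiffXd(3.0, np.zeros(2))])\n",
    "system_ad.get_input_port(0).FixValue(context_ad, x_ad)\n",
    "system_ad.get_input_port(1).FixValue(context_ad, u_ad)\n",
    "\n",
    "cost = system_ad.get_output_port().Eval(context_ad)[0]\n",
    "print(f\"cost: {cost.value()}\")\n",
    "print(f\"d(cost)/dx: {cost.derivatives()}\")"
   ]
  },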
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Further reading\n",
    "\n",
    "[Caching in the systems framework](https://drake.mit.edu/doxygen_cxx/group__cache__design__notes.html)\n",
    "\n",
    "[Declaring events](https://drake.mit.edu/doxygen_cxx/group__events__description.html)\n",
    "\n",
    "[Stochastic systems](https://drake.mit.edu/doxygen_cxx/group__stochastic__systems.html)\n",
    "\n",
    "Declaring [system constraints](https://drake.mit.edu/doxygen_cxx/classdrake_1_1systems_1_1_system_constraint.html): [equality](https://drake.mit.edu/doxygen_cxx/classdrake_1_1systems_1_1_leaf_system.html#a4948ad0241c67045b3c794874b2986a0) and [inequality](https://drake.mit.edu/doxygen_cxx/classdrake_1_1systems_1_1_leaf_system.html#a3bac306621c3f0839324649151c22af2).\n",
    "\n",
    "[Writing continuous dynamics in implicit form](https://drake.mit.edu/doxygen_cxx/classdrake_1_1systems_1_1_system.html#a2bb4c1e3572a8009863b5a342fcb5c49)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
